package com.gl.common.autoparts.utils;

import com.gl.common.autoparts.entity.TreeNode;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;

/**
 * Static helpers that assemble a flat list of {@link TreeNode}s into trees by matching each
 * node's parent id against other nodes' ids.
 *
 * <p>All methods mutate the nodes they are given: they create and append to the nodes'
 * children lists (and, for {@link #buildByRecursive}, set serial numbers). Calling a builder
 * twice on the same node instances appends the children a second time.
 */
public class TreeUtil {

    /**
     * Builds trees with a nested scan over the node list (O(n^2) comparisons).
     *
     * <p>Every node whose parent id equals {@code root} is returned as a top-level node. Every
     * node is attached as a child of each node whose id equals its parent id. The method name
     * keeps its historical spelling so existing callers continue to compile.
     *
     * @param treeNodes the flat node list; {@code null} is treated as empty
     * @param root      the parent-id value identifying top-level nodes; may be {@code null}
     * @param <T>       the concrete node type
     * @return the top-level nodes in input order, with descendants attached
     */
    public static <T extends TreeNode> List<T> bulid(List<T> treeNodes, Object root) {
        List<T> trees = new ArrayList<T>();
        if (treeNodes == null) {
            return trees;
        }
        for (T treeNode : treeNodes) {
            // Null-safe: top-level nodes frequently carry a null parent id, and root may be null.
            if (Objects.equals(root, treeNode.getParentId())) {
                trees.add(treeNode);
            }
            Object id = treeNode.getId();
            if (id == null) {
                // A node without an id cannot be referenced as anyone's parent.
                continue;
            }
            for (T it : treeNodes) {
                // Skip the node itself so a self-parented node does not become its own child.
                if (it != treeNode && id.equals(it.getParentId())) {
                    ensureChildren(treeNode);
                    treeNode.add(it);
                }
            }
        }
        return trees;
    }

    /**
     * Builds trees in a single pass using an id-to-node index.
     *
     * <p>A node becomes a root when no other node in the list has an id equal to its parent id
     * (including when its parent id is {@code null}). Otherwise it is appended to that parent's
     * children. If several nodes share an id, children attach to the last of them in list order.
     *
     * @param treeNodes the flat node list; {@code null} is treated as empty
     * @param <T>       the concrete node type
     * @return the root nodes in input order, with descendants attached
     */
    public static <T extends TreeNode> List<T> buildTree(List<T> treeNodes) {
        List<T> roots = new ArrayList<>();
        if (treeNodes == null) {
            return roots;
        }
        // Presize so the default 0.75 load factor does not trigger a rehash while indexing.
        Map<String, T> byId = new HashMap<>(treeNodes.size() * 4 / 3 + 1);
        for (T node : treeNodes) {
            // Do not index null ids: otherwise every node with a null parent id would be
            // attached to that node instead of being treated as a root.
            if (node.getId() != null) {
                byId.put(node.getId(), node);
            }
        }
        for (T node : treeNodes) {
            T parent = byId.get(node.getParentId());
            // A self-parented node is treated as a root rather than being added to itself.
            if (parent != null && parent != node) {
                ensureChildren(parent);
                parent.add(node);
            } else {
                roots.add(node);
            }
        }
        return roots;
    }

    /**
     * Builds trees recursively and assigns hierarchical serial numbers.
     *
     * <p>Top-level nodes (parent id equal to {@code root}) are numbered "1", "2", "3", ... in
     * input order. Each child is numbered "{parent serial}.{k}", where k counts its siblings.
     *
     * @param treeNodes the flat node list; {@code null} is treated as empty
     * @param root      the parent-id value identifying top-level nodes; may be {@code null}
     * @param <T>       the concrete node type
     * @return the top-level nodes in input order, with descendants attached
     */
    public static <T extends TreeNode> List<T> buildByRecursive(List<T> treeNodes, Object root) {
        List<T> trees = new ArrayList<T>();
        if (treeNodes == null) {
            return trees;
        }
        int num = 0;
        for (T treeNode : treeNodes) {
            if (Objects.equals(root, treeNode.getParentId())) {
                // Count only top-level nodes, so roots are numbered consecutively, the same way
                // findChildren numbers siblings. (The counter used to advance for every node.)
                num++;
                treeNode.setSerialNumber(String.valueOf(num));
                trees.add(findChildren(treeNode, treeNodes));
            }
        }
        return trees;
    }

    /**
     * Recursively attaches to {@code treeNode} every node in {@code treeNodes} whose parent id
     * equals its id.
     *
     * <p>Each attached child receives the serial number "{parent serial}.{k}" (k = 1, 2, ...).
     * Cycles longer than a single self-reference are not detected and will still recurse without
     * bound.
     *
     * @param treeNode  the node whose children are collected; it is mutated and returned
     * @param treeNodes the flat node list to search; {@code null} is treated as empty
     * @param <T>       the concrete node type
     * @return {@code treeNode}, with its descendants attached
     */
    public static <T extends TreeNode> T findChildren(T treeNode, List<T> treeNodes) {
        Object id = treeNode.getId();
        if (id == null || treeNodes == null) {
            // A node without an id cannot have children referencing it.
            return treeNode;
        }
        int num = 0;
        for (T it : treeNodes) {
            // Skip the node itself: a self-parented node would otherwise recurse forever.
            if (it != treeNode && id.equals(it.getParentId())) {
                ensureChildren(treeNode);
                num++;
                it.setSerialNumber(treeNode.getSerialNumber() + "." + num);
                treeNode.add(findChildren(it, treeNodes));
            }
        }
        return treeNode;
    }

    /** Initialises {@code node}'s children list if it is still {@code null}. */
    private static void ensureChildren(TreeNode node) {
        if (node.getChildren() == null) {
            node.setChildren(new ArrayList<TreeNode>());
        }
    }

}
